import keras.activations
import tensorflow as tf
import tensorflow_probability.python.distributions as tfd
import time
import numpy as np

from keras.layers import *
from utils.logger import Logger



class SplitEmbedding(tf.keras.Model):
    """Expose one half of a backbone's output along its last axis.

    The output of ``conv`` is cut in two at ``last_dim // 2``.
    ``component == 0`` selects the leading half; any other value selects the
    trailing half (which is one element larger when the last dimension is
    odd).
    """

    def __init__(self, conv, component):
        super().__init__()
        self.conv = conv
        self.component = component

    def call(self, inputs):
        """Run ``self.conv`` on ``inputs`` and return the selected half.

        The result is indexed on four axes, so the backbone output has to be
        rank 4.
        """
        features = self.conv(inputs)
        half = features.shape[-1] // 2
        # Choose the range on the last axis once, then slice a single time.
        last_axis = slice(None, half) if self.component == 0 else slice(half, None)
        return features[:, :, :, last_axis]

class OperatorConvolutionsSmaller(tf.keras.Model):
    """Three-stage convolutional feature extractor.

    Every stage applies, in this order: convolution, max pooling, optional
    dropout, ReLU. The stages use 16, 32 and 64 filters and share the kernel
    ``size``. When ``dense`` is truthy, a 30-unit dense layer followed by a
    ReLU is applied to the output of the last stage.

    :param dense: enable the trailing dense layer.
    :param size: kernel size handed to every ``Conv2D``.
    :param dropout: dropout rate, or ``None`` to disable dropout.
    """

    def __init__(self, dense=False, size=9, dropout=None):
        super().__init__()
        self.size = size
        self.dense = dense
        self.conv1 = Conv2D(16, self.size)
        self.conv2 = Conv2D(32, self.size)
        self.conv3 = Conv2D(64, self.size)
        # With no rate given, use an identity that has the same
        # (x, training) signature so ``call`` never has to branch on it.
        if dropout is None:
            self.potential_dropout = lambda x, t: x
        else:
            self.potential_dropout = Dropout(dropout)
        self.pool = MaxPool2D()
        self.relu = ReLU()
        # ``self.dense`` begins as the flag and is replaced by the layer
        # itself when enabled; ``call`` relies on the attribute's truthiness.
        if self.dense:
            self.dense = Dense(30)

    def call(self, inputs, training=None, mask=None):
        """Return the features produced by the three stages.

        ``training`` is passed to the dropout step of every stage; ``mask``
        is accepted but not used.
        """
        features = inputs
        for conv in (self.conv1, self.conv2, self.conv3):
            features = conv(features)
            features = self.pool(features)
            features = self.potential_dropout(features, training)
            features = self.relu(features)
        if not self.dense:
            return features
        return self.relu(self.dense(features))

class ObjectDetectorRegressor(tf.keras.Model):
    """Region regressor on top of a shared convolutional backbone.

    ``simple=True``: the backbone output is split into two halves (via
    ``SplitEmbedding``) and both halves go through one shared
    ``OperatorNetRegrSeparateSmaller``.

    ``simple=False``: a convolutional head emits ``N + 1`` values per cell;
    channel 0 becomes a score and the following channels become
    (mu, sigma) parameter pairs.

    :param conv: backbone model shared by all paths.
    :param N: number of regressed values (see ``call`` for how it is used).
    :param size: base window size of the first head convolution
        (non-simple mode only).
    :param simple: selects between the two modes described above.
    """

    def __init__(self, conv, N=2, size=5, simple=True):
        super().__init__()
        self.N = N
        self.simple = simple
        self.modules = [conv]
        if not self.simple:
            window = (2 * size, int(size))
            # The first convolution uses the same tuple for kernel and stride.
            head_layers = (
                Conv2D(64, window, strides=window, activation='relu'),
                Conv2D(64, 1, activation='relu'),
                Conv2D(self.N + 1, 1),
                Reshape([-1, self.N + 1]),
            )
            for layer in head_layers:
                self.modules.append(layer)

            self.model = tf.keras.Sequential(self.modules)
        else:
            self.emb1 = SplitEmbedding(conv, 0)
            self.emb2 = SplitEmbedding(conv, 1)
            self.regressor = OperatorNetRegrSeparateSmaller(N=self.N, h1=48, h2=24)

    def call(self, image):
        """Regress region parameters for ``image``.

        Simple mode returns ``[left, right]``, or, when ``N > 2``, each of
        them split into its first two and remaining columns
        (``[left[:, :2], left[:, 2:], right[:, :2], right[:, 2:]]``).

        Non-simple mode returns one entry per cell: ``[score, params]`` or,
        when ``N > 2``, ``[score, params_x, params_y]``; ``score`` has two
        trailing singleton axes appended.
        """
        if self.simple:
            first_half = self.emb1.call(image)
            second_half = self.emb2.call(image)
            left = self.regressor.call(first_half)
            right = self.regressor.call(second_half)
            if self.N > 2:
                return [left[:, :2], left[:, 2:], right[:, :2], right[:, 2:]]
            return [left, right]

        raw = self.model.call(image)
        cells = raw.shape[1]
        scores = tf.math.sigmoid(raw[:, :, 0])
        mu = (tf.math.tanh(raw[:, :, 1]) + 1) / 3
        sigma = tf.math.exp(raw[:, :, 2])

        # Per-cell offset k / (2 * cells), broadcast over the batch; together
        # with the constant 1 / 8 it shifts every cell's mu.
        # NOTE(review): the origin of the 1/3, 1/(2n) and 1/8 constants is not
        # visible here — confirm against the intended cell layout.
        batch_ones = tf.ones_like(tf.reduce_mean(mu, axis=-1))
        offsets = tf.stack(
            [batch_ones * k / (2 * cells) for k in range(cells)], axis=-1
        )
        mu = mu + (offsets + 1 / 8)
        params = tf.stack([mu, sigma], axis=-1)

        def as_column(score):
            # Append two singleton axes to a per-example score.
            return tf.expand_dims(tf.expand_dims(score, axis=-1), axis=-1)

        if self.N > 2:
            mu_y = tf.math.sigmoid(raw[:, :, 3])
            sigma_y = tf.math.exp(raw[:, :, 4])
            params_y = tf.stack([mu_y, sigma_y], axis=-1)
            per_cell = zip(
                tf.unstack(scores, axis=-1),
                tf.unstack(params, axis=-2),
                tf.unstack(params_y, axis=-2),
            )
            return [[as_column(score), px, py] for score, px, py in per_cell]

        per_cell = zip(tf.unstack(scores, axis=-1), tf.unstack(params, axis=-2))
        return [[as_column(score), px] for score, px in per_cell]

class ObjectDetectorClassifier(tf.keras.Model):
    """Classify the image region selected by a hard rectangular mask.

    Along each axis the region is ``[mu - mult * sigma, mu + mult * sigma]``
    clipped to ``[0, 1]`` and scaled to pixel indices. Pixels outside the
    rectangle are set to ``-1`` (``(image + 1) * 0 - 1``) before the masked
    image is handed to the classification head.

    :param conv: backbone placed at the front of the head.
    :param N: number of output classes.
    :param h: base hidden width of the dense head (``simple=False``).
    :param simple: use the convolution + global-max-pool head instead of the
        dense head.
    :param generalised: when true the interval half-width is ``1 * sigma``
        instead of ``2 * sigma``.
    :param image_size: side length in pixels of the (square) input images;
        used to build the masks. Defaults to the previously hard-coded 128.
    """

    def __init__(self, conv, N=10, h=24, simple=True, generalised=False,
                 image_size=128):
        super(ObjectDetectorClassifier, self).__init__()
        self.N = N
        self.h = h
        self.simple = simple
        self.generalised = generalised
        self.image_size = image_size
        self.mult = 1 if self.generalised else 2
        self.modules = [conv]
        if self.simple:
            self.modules.append(Conv2D(self.N, 5))
            self.modules.append(GlobalMaxPooling2D())
            self.modules.append(Softmax())
        else:
            self.modules.append(Flatten())
            self.modules.append(Dense(2 * self.h, activation='relu'))
            self.modules.append(Dense(self.h, activation='relu'))
            self.modules.append(Dense(self.N, activation='softmax'))

        self.model = tf.keras.Sequential(self.modules)

    def _interval_mask(self, location):
        """Return a ``(batch, image_size)`` 0/1 float mask for one axis.

        ``location[:, 0]`` is read as mu and ``location[:, 1]`` as sigma.
        Index ``i`` is inside when ``lower <= i <= upper`` with both bounds
        truncated to integers after scaling by ``image_size``.
        """
        mu, sigma = location[:, 0], location[:, 1]
        lower = tf.math.maximum(mu - self.mult * sigma, 0.)
        upper = tf.math.minimum(mu + self.mult * sigma, 1.)
        lower = tf.cast(lower * self.image_size, tf.int32)
        upper = tf.cast(upper * self.image_size, tf.int32)
        index = tf.range(self.image_size, dtype=tf.int32)
        # One broadcast comparison replaces a per-pixel Python loop.
        inside = tf.logical_and(
            tf.expand_dims(lower, -1) <= index,
            index <= tf.expand_dims(upper, -1),
        )
        return tf.cast(inside, tf.float32)

    def call(self, image, locationx, locationsy, training=None):
        """Mask ``image`` with the given box and classify the result.

        :param image: image batch; the masks broadcast against axes 1 and 2.
            Assumes a (batch, image_size, image_size, channels) layout with
            background value -1 — TODO confirm against the data pipeline.
        :param locationx: ``(batch, >=2)`` tensor, columns 0/1 = mu/sigma for
            the mask that varies along axis 2.
        :param locationsy: same, for the mask that varies along axis 1.
        :param training: forwarded to the classification head.
        :return: output of the classification head.
        """
        # Shapes (batch, 1, size, 1) and (batch, size, 1, 1): broadcasting
        # yields the same values as fully tiled masks.
        mask = self._interval_mask(locationx)[:, tf.newaxis, :, tf.newaxis]
        masky = self._interval_mask(locationsy)[:, :, tf.newaxis, tf.newaxis]
        masked_image = (image + 1) * mask * masky - 1

        return self.model.call(masked_image, training=training)

# The generalised OD utilises soft masking
class GeneralisedObjectDetectorClassifier(tf.keras.Model):
    """Classify the image region selected by a soft (differentiable) mask.

    Along each axis the mask is the density of a generalised normal
    distribution with ``power=8``, evaluated at the normalised pixel
    positions ``i / image_size`` and divided by its value at ``loc`` so that
    the peak equals 1. The masked image is ``(image + 1) * mask - 1``.

    :param conv: backbone placed at the front of the head.
    :param N: number of output classes.
    :param h: base hidden width of the dense head (``simple=False``).
    :param simple: use the convolution + global-max-pool head instead of the
        dense head.
    :param image_size: side length in pixels of the (square) input images.
        Defaults to the previously hard-coded 128.
    """

    def __init__(self, conv, N=10, h=24, simple=True, image_size=128):
        super(GeneralisedObjectDetectorClassifier, self).__init__()
        self.N = N
        self.h = h
        self.simple = simple
        self.image_size = image_size
        self.modules = [conv]

        # Normalised pixel positions 0, 1/size, ..., (size - 1)/size.
        self.indextensor = tf.constant(
            [i / image_size for i in range(image_size)]
        )

        if self.simple:
            self.modules.append(Conv2D(self.N, 5))
            self.modules.append(GlobalMaxPooling2D())
            self.modules.append(Softmax())
        else:
            self.modules.append(Flatten())
            self.modules.append(Dense(2 * self.h, activation='relu'))
            self.modules.append(Dense(self.h, activation='relu'))
            self.modules.append(Dense(self.N, activation='softmax'))

        self.model = tf.keras.Sequential(self.modules)

    def _soft_mask(self, location):
        """Return a ``(batch, image_size)`` soft mask for one axis.

        ``location[:, 0]`` is used as ``loc`` and ``location[:, 1]`` as
        ``scale`` of the generalised normal distribution.
        """
        mu = tf.expand_dims(location[:, 0], axis=-1)
        sigma = tf.expand_dims(location[:, 1], axis=-1)
        distr = tfd.GeneralizedNormal(loc=mu, scale=sigma, power=8.)
        # Dividing by the density at the mode rescales the peak to 1.
        return distr.prob(self.indextensor) / distr.prob(mu)

    def call(self, image, locationx, locationsy, training=None):
        """Softly mask ``image`` with the given box and classify the result.

        :param image: image batch; the masks broadcast against axes 1 and 2.
            Assumes a (batch, image_size, image_size, channels) layout with
            background value -1 — TODO confirm against the data pipeline.
        :param locationx: ``(batch, >=2)`` tensor, columns 0/1 = loc/scale
            for the mask that varies along axis 2.
        :param locationsy: same, for the mask that varies along axis 1.
        :param training: forwarded to the classification head.
        :return: output of the classification head.
        """
        # Shapes (batch, 1, size, 1) and (batch, size, 1, 1): broadcasting
        # replaces tiling via matmul-with-ones and keeps the mask values
        # untouched by any matmul kernel.
        maskx = self._soft_mask(locationx)[:, tf.newaxis, :, tf.newaxis]
        masky = self._soft_mask(locationsy)[:, :, tf.newaxis, tf.newaxis]

        masked_image = (image + 1) * maskx * masky - 1

        return self.model.call(masked_image, training=training)
        
class OperatorNetRegrSeparateSmaller(tf.keras.Model):
    """Small dense regressor: flatten -> dense(h1) -> ReLU -> dense(N).

    For ``N == 2`` and ``N == 4`` the outputs are post-processed in pairs:
    even columns go through ``(tanh(x) + 1) / 2`` and odd columns through
    ``exp(x)``. For any other ``N`` the raw dense output is returned.

    :param N: number of output units.
    :param h1: width of the hidden layer that is used.
    :param h2: width of ``dense2``, which is created but not applied in
        ``call``.
    """

    def __init__(self, N=1, h1=120, h2=80):
        super().__init__()
        self.N = N
        self.flatten = Flatten()
        self.dense1 = Dense(h1)
        self.dense2 = Dense(h2)
        self.dense3 = Dense(self.N)
        self.relu = ReLU()

    def call(self, inputs, training=None, mask=None):
        """Return the regressed values; ``training``/``mask`` are unused."""
        hidden = self.relu(self.dense1(self.flatten(inputs)))
        # NOTE: dense2 is created in __init__ but is not applied here.
        raw = self.dense3(hidden)

        def squashed(col):
            # (tanh + 1) / 2 maps the column into the open interval (0, 1).
            return (tf.math.tanh(raw[:, col]) + 1) / 2

        def positive(col):
            # exp keeps the column strictly positive.
            return tf.math.exp(raw[:, col])

        if self.N == 2:
            return tf.stack([squashed(0), positive(1)], axis=1)
        if self.N == 4:
            return tf.stack(
                [squashed(0), positive(1), squashed(2), positive(3)], axis=1
            )
        return raw


class StructuredODBaseline(tf.keras.Model):
    """Baseline: region proposals -> per-region classifier -> dense head.

    ``rpn`` yields four location tensors (left x/y, right x/y), the shared
    ``classifier`` is applied to the image once per region, and the two
    class distributions are concatenated and mapped to a 19-way softmax by a
    small dense network.

    :param rpn: model whose ``call(images)`` returns four location tensors.
    :param classifier: model with ``call(images, location_x, location_y)``.
    :param memoryset: optional indexable collection of supervised examples
        used for the additional location regression loss (see ``_loss``).

    The training loop expects ``self.optimiser`` to have been assigned by the
    caller before ``train`` is invoked; it is not created here.
    """

    def __init__(self, rpn, classifier, memoryset=None):
        super(StructuredODBaseline, self).__init__()
        self.rpn = rpn
        self.classifier = classifier
        self.memoryset = memoryset
        self.logger = Logger()
        self.counter = 0

        self.modules = []
        self.modules.append(Dense(128, activation='relu'))
        self.modules.append(Dense(96, activation='relu'))
        self.modules.append(Dense(64, activation='relu'))
        self.modules.append(Dense(19, activation='softmax'))

        self.logicnet = tf.keras.Sequential(self.modules)

    def call(self, images):
        """Return the 19-way softmax output for a batch of images."""
        leftregx, leftregy, rightregx, rightregy = self.rpn.call(images)
        C1 = self.classifier.call(images, leftregx, leftregy)
        C2 = self.classifier.call(images, rightregx, rightregy)
        x = tf.concat([C1, C2], axis=-1)
        x = self.logicnet.call(x)
        return x

    def _loss(self, inputs, targets):
        """Compute the scalar training loss for one batch.

        The base term is categorical cross-entropy between the model output
        and a one-hot encoding of ``targets`` at index ``9 + int(target)``
        (presumably targets lie in [-9, 9] — confirm against the dataset).

        When ``self.memoryset`` is set, ``1e3`` times a regression term is
        added. That term is the sum of mean squared errors which
        (a) pull column 1 of every location predicted for ``inputs`` towards
        ``7 / 128`` and (b) pull the locations predicted for one randomly
        drawn memory entry towards its stored coordinates, again with column
        1 pulled towards ``7 / 128``.
        """
        y_pred = self.call(inputs)
        y_true = []
        for target in targets:
            one_hot = np.zeros(19)
            one_hot[9 + int(target)] = 1.
            y_true.append(tf.constant(one_hot))
        y_true = tf.stack(y_true)
        loss_value = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(y_true, y_pred)

        if self.memoryset is not None:
            # The regularisation only ever reaches the loss in this branch,
            # so the extra forward pass is done here and not unconditionally.
            MSE = tf.keras.losses.MeanSquaredError()
            leftregx, leftregy, rightregx, rightregy = self.rpn.call(inputs)

            boxwidth = tf.ones_like(leftregx[:, 1]) * 7 / 128

            MSEloss_value = MSE(leftregx[:, 1], boxwidth) + MSE(leftregy[:, 1], boxwidth)
            MSEloss_value += MSE(rightregx[:, 1], boxwidth) + MSE(rightregy[:, 1], boxwidth)

            # One random memory entry: element [0][0] holds the images and
            # elements [0][3:] the four target coordinates.
            rd_id = np.random.randint(0, len(self.memoryset))
            mem_images = self.memoryset[rd_id][0][0]
            mem_x1, mem_x2, mem_y1, mem_y2 = self.memoryset[rd_id][0][3:]
            boxwidth = tf.ones_like(mem_x1) * 7 / 128

            leftregx, leftregy, rightregx, rightregy = self.rpn.call(mem_images)

            MSEloss_value += MSE(leftregx[:, 0], mem_x1) + MSE(leftregx[:, 1], boxwidth)
            MSEloss_value += MSE(leftregy[:, 0], mem_y1) + MSE(leftregy[:, 1], boxwidth)
            MSEloss_value += MSE(rightregx[:, 0], mem_x2) + MSE(rightregx[:, 1], boxwidth)
            MSEloss_value += MSE(rightregy[:, 0], mem_y2) + MSE(rightregy[:, 1], boxwidth)
            loss_value += 1e3 * MSEloss_value

        return loss_value

    def grad(self, inputs, targets):
        """Return ``(loss, gradients)`` for one batch.

        The gradients are taken with respect to ``self.trainable_variables``
        and are ordered like that list.
        """
        with tf.GradientTape() as tape:
            loss_value = self._loss(inputs, targets)

        return loss_value, tape.gradient(loss_value, self.trainable_variables)

    def train(self, data, epochs, update_its=1, log_its=100, val_data=None,
              eval_fns=None, fn_args=None):
        """
        Trains all weights present in the model graph.

        :param data: iterable of ``(x, y)`` pairs; ``x[0]`` is used as the
            input batch and ``x[1]`` as the targets, ``y`` is ignored.
        :param epochs: number of passes over ``data``.
        :param update_its: gradient updates performed per batch.
        :param log_its: log every ``log_its`` batches (counted across epochs
            via ``self.counter``).
        :param val_data: optional iterable with the same structure as
            ``data``; when given, its loss is reported at every log step.
        :param eval_fns: optional list of callables evaluated at every log
            step when ``val_data`` is given.
        :param fn_args: per-callable argument tuples for ``eval_fns``.
        :return: None
        """
        for epoch in range(epochs):
            print("Epoch {}".format(epoch + 1))
            accumulated_loss = 0
            acc_eval_time = 0
            prev_iter_time = time.time()
            for x, y in data:
                for j in range(update_its):
                    prev_eval_time = time.time()
                    loss_val, grads = self.grad(x[0], x[1])
                    accumulated_loss += loss_val.numpy()
                    acc_eval_time += time.time() - prev_eval_time
                    self.optimiser.apply_gradients(zip(grads, self.trainable_variables))
                self.counter += 1
                if self.counter % log_its == 0:
                    update_time = time.time() - prev_iter_time
                    if val_data is None:
                        print(
                            "Iteration: ",
                            self.counter,
                            "\ts:%.4f" % (update_time),
                            "\tAverage Loss: ",
                            accumulated_loss / log_its
                        )
                    else:
                        val_loss = 0
                        val_counter = 0
                        for x2, y2 in val_data:
                            # Only the loss is needed; no gradients are taken.
                            val_loss += self._loss(x2[0], x2[1]).numpy()
                            val_counter += 1
                        accs = []
                        if eval_fns is not None:
                            for fn_id, fn in enumerate(eval_fns):
                                acc = fn(*fn_args[fn_id])
                                self.logger.log("val_accs{}".format(fn_id), self.counter, acc)
                                accs.append(acc)
                        print(
                            "Iteration: ",
                            self.counter,
                            "\ts:%.4f" % (update_time),
                            "\tAverage Loss: ",
                            accumulated_loss / log_its,
                            "\tValidation Loss: ",
                            # Empty validation data yields NaN, not a crash.
                            val_loss / val_counter if val_counter else float("nan"),
                            "\tValidation Accs: ",
                            accs
                        )
                    self.log(self.counter, accumulated_loss, acc_eval_time, update_time, log_iter=log_its)
                    # Start a fresh window for both running totals so that the
                    # logged averages cover only the last ``log_its`` batches.
                    accumulated_loss = 0
                    acc_eval_time = 0
                    prev_iter_time = time.time()

    def log(
        self, counter, acc_loss, eval_timing, it_timing, snapshot_iter=None,
            log_iter=100, verbose=1, **kwargs
    ):
        """Record timing/loss statistics and optionally write a snapshot.

        :param counter: current iteration number.
        :param acc_loss: loss accumulated since the previous log step.
        :param eval_timing: time spent computing gradients since the previous
            log step.
        :param it_timing: wall-clock time since the previous log step.
        :param snapshot_iter: write a snapshot every ``snapshot_iter``
            iterations; requires ``snapshot_name`` in ``kwargs``.
        :param log_iter: logging period; also the divisor for the averages.
        :param verbose: falsy value disables the statistics logging.
        """
        if (
            "snapshot_name" in kwargs
            and snapshot_iter is not None
            and counter % snapshot_iter == 0
        ):
            filename = "{}_iter_{}.mdl".format(kwargs["snapshot_name"], counter)
            print("Writing snapshot to " + filename)
            # NOTE(review): ``save_state`` is not defined in this class —
            # confirm it is provided elsewhere before enabling snapshots.
            self.save_state(filename)
        if verbose and counter % log_iter == 0:
            self.logger.log("time", counter, it_timing)
            self.logger.log("loss", counter, acc_loss / log_iter)
            self.logger.log("eval_time", counter, eval_timing / log_iter)

